#!/usr/bin/env guile
# -*- scheme -*-
!#

(use-modules (ice-9 match)
             (fibers)
             (fibers channels))

;; Create and return a fresh channel OUT, and spawn a fiber (with
;; #:parallel? #t) that forever reads a value from IN and puts its square
;; on OUT.
(define (make-squarer in)
  (define out (make-channel))
  (define (pump)
    (let ((v (get-message in)))
      (put-message out (* v v)))
    (pump))
  (spawn-fiber pump #:parallel? #t)
  out)

;; Create and return a fresh channel OUT, and spawn a fiber (with
;; #:parallel? #t) that forever reads a value from IN and puts its square
;; root (via `sqrt') on OUT.
(define (make-sqrter in)
  (define out (make-channel))
  (define (pump)
    (let ((v (get-message in)))
      (put-message out (sqrt v)))
    (pump))
  (spawn-fiber pump #:parallel? #t)
  out)

;; Create DIMENSIONS fresh channels and return them as a list.  A fiber
;; (spawned with #:parallel? #t) forever reads one message from IN and puts
;; that same message on each returned channel in list order, blocking on
;; each put before moving on to the next.
(define (make-broadcaster in dimensions)
  (define fan-out
    (map (lambda (_index) (make-channel))
         (iota dimensions)))
  (define (relay)
    (let ((msg (get-message in)))
      (let send ((chans fan-out))
        (unless (null? chans)
          (put-message (car chans) msg)
          (send (cdr chans)))))
    (relay))
  (spawn-fiber relay #:parallel? #t)
  fan-out)

;; IN is a list of channels.  Create and return a fresh channel OUT, and
;; spawn a fiber (with #:parallel? #t) that repeats forever: read one
;; message from each channel of IN, in list order, then put their sum
;; (starting from 0) on OUT.  The reads are sequential, so each round
;; waits on every input channel in turn.
(define (make-summer in)
  (define out (make-channel))
  (define (sum-one-round)
    (let gather ((acc 0) (chans in))
      (if (null? chans)
          acc
          (gather (+ acc (get-message (car chans)))
                  (cdr chans)))))
  (define (forever)
    (put-message out (sum-one-round))
    (forever))
  (spawn-fiber forever #:parallel? #t)
  out)

;; Create and return a fresh channel, and spawn a fiber (with
;; #:parallel? #t) that puts 0, 1, 2, ... on it without end.
(define (make-counter)
  (define out (make-channel))
  (define (count-from n)
    (put-message out n)
    (count-from (+ n 1)))
  (spawn-fiber (lambda () (count-from 0)) #:parallel? #t)
  out)

;; Call MAKE-HEAD once to get a value, then call MAKE-TAIL on that value
;; repeatedly while a counter that starts at DIMENSIONS and drops by 1 each
;; time is still positive.  For a non-negative integer that means exactly
;; DIMENSIONS calls; zero or negative means none.  MAKE-TAIL's results are
;; discarded.
;; NOTE(review): nothing in this file calls make-diagonal -- confirm whether
;; it is still needed.
(define (make-diagonal dimensions make-head make-tail)
  (let ((head (make-head)))
    (do ((remaining dimensions (- remaining 1)))
        ((not (positive? remaining)))
      (make-tail head))))

;; Build the pipeline
;;   counter -> broadcaster (DIMENSIONS ways) -> one squarer per branch
;;           -> summer -> sqrter
;; then read MESSAGE-COUNT values from the final channel and discard them.
(define (test dimensions message-count)
  (define lens
    (make-sqrter
     (make-summer
      (map make-squarer
           (make-broadcaster (make-counter) dimensions)))))
  (do ((received 0 (+ received 1)))
      ((not (< received message-count)))
    (get-message lens)))

;; Script entry point.  ARGS is the full argument list with the program
;; name first, as passed in from (program-arguments).  Expects exactly two
;; more arguments, DIMENSIONS and MESSAGE-COUNT, each a non-negative exact
;; integer, and runs the benchmark pipeline inside run-fibers.
;; On a wrong argument count or an unparsable/invalid number, prints a usage
;; message to the current error port and exits with status 1.  Without this
;; check, bad input only fails later, inside the fibers.
(define (main args)
  (define (usage)
    (format (current-error-port)
            "usage: ~a DIMENSIONS MESSAGE-COUNT\n  both arguments must be non-negative integers\n"
            (if (pair? args) (car args) "benchmark"))
    (exit 1))
  ;; string->number returns #f for non-numeric text, so check the result
  ;; before it reaches iota or the read loop.
  (define (parse str)
    (let ((n (string->number str)))
      (if (and (integer? n) (exact? n) (not (negative? n)))
          n
          (usage))))
  (match args
    ((_ dimensions message-count)
     (let ((dimensions (parse dimensions))
           (message-count (parse message-count)))
       (run-fibers (lambda () (test dimensions message-count)))))
    (_ (usage))))

;; Run main with the process arguments only when Guile reports batch mode
;; (batch-mode? is true), e.g. when the file is executed as a script.
(if (batch-mode?)
    (main (program-arguments)))
